{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NLP Code Along\n",
    "\n",
    "For this code-along, we will build a spam filter! We'll use the NLP tools we've learned about, as well as a new classifier, Naive Bayes.\n",
    "\n",
    "We'll use a classic dataset for this, the SMS Spam Collection from the UCI Machine Learning Repository: https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "spark = SparkSession.builder.appName('nlp').getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data = spark.read.csv(\"smsspamcollection/SMSSpamCollection\",inferSchema=True,sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data = data.withColumnRenamed('_c0','class').withColumnRenamed('_c1','text')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+\n",
      "|class|                text|\n",
      "+-----+--------------------+\n",
      "|  ham|Go until jurong p...|\n",
      "|  ham|Ok lar... Joking ...|\n",
      "| spam|Free entry in 2 a...|\n",
      "|  ham|U dun say so earl...|\n",
      "|  ham|Nah I don't think...|\n",
      "| spam|FreeMsg Hey there...|\n",
      "|  ham|Even my brother i...|\n",
      "|  ham|As per your reque...|\n",
      "| spam|WINNER!! As a val...|\n",
      "| spam|Had your mobile 1...|\n",
      "|  ham|I'm gonna be home...|\n",
      "| spam|SIX chances to wi...|\n",
      "| spam|URGENT! You have ...|\n",
      "|  ham|I've been searchi...|\n",
      "|  ham|I HAVE A DATE ON ...|\n",
      "| spam|XXXMobileMovieClu...|\n",
      "|  ham|Oh k...i'm watchi...|\n",
      "|  ham|Eh u remember how...|\n",
      "|  ham|Fine if thats th...|\n",
      "| spam|England v Macedon...|\n",
      "+-----+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean and Prepare the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Create a new length feature:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.sql.functions import length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data = data.withColumn('length',length(data['text']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+------+\n",
      "|class|                text|length|\n",
      "+-----+--------------------+------+\n",
      "|  ham|Go until jurong p...|   111|\n",
      "|  ham|Ok lar... Joking ...|    29|\n",
      "| spam|Free entry in 2 a...|   155|\n",
      "|  ham|U dun say so earl...|    49|\n",
      "|  ham|Nah I don't think...|    61|\n",
      "| spam|FreeMsg Hey there...|   147|\n",
      "|  ham|Even my brother i...|    77|\n",
      "|  ham|As per your reque...|   160|\n",
      "| spam|WINNER!! As a val...|   157|\n",
      "| spam|Had your mobile 1...|   154|\n",
      "|  ham|I'm gonna be home...|   109|\n",
      "| spam|SIX chances to wi...|   136|\n",
      "| spam|URGENT! You have ...|   155|\n",
      "|  ham|I've been searchi...|   196|\n",
      "|  ham|I HAVE A DATE ON ...|    35|\n",
      "| spam|XXXMobileMovieClu...|   149|\n",
      "|  ham|Oh k...i'm watchi...|    26|\n",
      "|  ham|Eh u remember how...|    81|\n",
      "|  ham|Fine if thats th...|    56|\n",
      "| spam|England v Macedon...|   155|\n",
      "+-----+--------------------+------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----------------+\n",
      "|class|      avg(length)|\n",
      "+-----+-----------------+\n",
      "|  ham|71.48663212435233|\n",
      "| spam|138.6706827309237|\n",
      "+-----+-----------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Spam messages are roughly twice as long as ham on average\n",
    "data.groupby('class').mean().show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature Transformations"
   ]
  },
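  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stages defined in the next cell turn each message into a TF-IDF vector. As a rough illustration of that arithmetic, here is a minimal pure-Python sketch. It assumes raw term counts scaled by the smoothed IDF `log((m + 1) / (df + 1))` described in the Spark MLlib documentation; the `tf_idf` helper and the toy messages are hypothetical and exist only for this example.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "def tf_idf(docs):\n",
    "    # Return one {term: weight} dict per tokenized document.\n",
    "    m = len(docs)\n",
    "    # Document frequency: number of documents that contain each term.\n",
    "    df = Counter(term for doc in docs for term in set(doc))\n",
    "    # Smoothed inverse document frequency (assumed formula, see lead-in).\n",
    "    idf = {term: math.log((m + 1) / (n + 1)) for term, n in df.items()}\n",
    "    # Raw term counts scaled by IDF.\n",
    "    return [{t: c * idf[t] for t, c in Counter(doc).items()} for doc in docs]\n",
    "\n",
    "# Hypothetical toy messages, already lower-cased and split on whitespace.\n",
    "docs = [['free', 'entry', 'win'], ['ok', 'see', 'you'], ['win', 'free', 'free']]\n",
    "weights = tf_idf(docs)\n",
    "```\n",
    "\n",
    "Terms that appear in fewer messages get a larger weight: here `entry` (one message) is weighted by `log(4/2)`, while `free` (two messages) is weighted by `log(4/3)` per occurrence."
   ]
  },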
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "collapsed": false
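   },
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import (Tokenizer, StopWordsRemover, CountVectorizer,\n",
    "                                IDF, StringIndexer)\n",
    "\n",
    "# Lower-case each message and split it on whitespace\n",
    "tokenizer = Tokenizer(inputCol=\"text\", outputCol=\"token_text\")\n",
    "# Drop common English stop words\n",
    "stopremove = StopWordsRemover(inputCol='token_text', outputCol='stop_tokens')\n",
    "# Count term occurrences, then rescale the counts by inverse document frequency\n",
    "count_vec = CountVectorizer(inputCol='stop_tokens', outputCol='c_vec')\n",
    "idf = IDF(inputCol=\"c_vec\", outputCol=\"tf_idf\")\n",
    "# Encode the 'class' strings as numeric labels (the most frequent class, ham, becomes 0.0)\n",
    "ham_spam_to_num = StringIndexer(inputCol='class', outputCol='label')"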
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "clean_up = VectorAssembler(inputCols=['tf_idf','length'],outputCol='features')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Model\n",
    "\n",
    "We'll use Naive Bayes, but feel free to play around with this choice!"
   ]
  },
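  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before relying on `NaiveBayes`, it may help to see the idea in miniature. The following is a minimal sketch of a textbook multinomial Naive Bayes with add-one smoothing, written in plain Python. It is not Spark's implementation, and `train_nb`, `predict_nb`, and the toy messages are hypothetical names and data made up for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "def train_nb(docs, labels, smoothing=1.0):\n",
    "    # Fit a log-prior and smoothed per-term log-likelihoods for each class.\n",
    "    vocab = {t for doc in docs for t in doc}\n",
    "    term_counts = defaultdict(Counter)\n",
    "    class_counts = Counter(labels)\n",
    "    for doc, label in zip(docs, labels):\n",
    "        term_counts[label].update(doc)\n",
    "    model = {}\n",
    "    for label, n_docs in class_counts.items():\n",
    "        total = sum(term_counts[label].values()) + smoothing * len(vocab)\n",
    "        log_lik = {t: math.log((term_counts[label][t] + smoothing) / total) for t in vocab}\n",
    "        model[label] = (math.log(n_docs / len(docs)), log_lik)\n",
    "    return model\n",
    "\n",
    "def predict_nb(model, doc):\n",
    "    # Pick the class with the highest log-posterior; unseen terms are ignored.\n",
    "    def score(label):\n",
    "        log_prior, log_lik = model[label]\n",
    "        return log_prior + sum(log_lik.get(t, 0.0) for t in doc)\n",
    "    return max(model, key=score)\n",
    "\n",
    "# Hypothetical toy training set.\n",
    "docs = [['win', 'free', 'prize'], ['free', 'entry'], ['see', 'you', 'later'], ['ok', 'see', 'you']]\n",
    "labels = ['spam', 'spam', 'ham', 'ham']\n",
    "model = train_nb(docs, labels)\n",
    "```\n",
    "\n",
    "With this toy data, `predict_nb(model, ['free', 'prize'])` picks `'spam'` and `predict_nb(model, ['see', 'you'])` picks `'ham'`."
   ]
  },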
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import NaiveBayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Use defaults\n",
    "nb = NaiveBayes()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_prep_pipe = Pipeline(stages=[ham_spam_to_num,tokenizer,stopremove,count_vec,idf,clean_up])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "cleaner = data_prep_pipe.fit(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "clean_data = cleaner.transform(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and Evaluation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "clean_data = clean_data.select(['label','features'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+\n",
      "|label|            features|\n",
      "+-----+--------------------+\n",
      "|  0.0|(13461,[8,12,33,6...|\n",
      "|  0.0|(13461,[0,26,308,...|\n",
      "|  1.0|(13461,[2,14,20,3...|\n",
      "|  0.0|(13461,[0,73,84,1...|\n",
      "|  0.0|(13461,[36,39,140...|\n",
      "|  1.0|(13461,[11,57,62,...|\n",
      "|  0.0|(13461,[11,55,108...|\n",
      "|  0.0|(13461,[133,195,4...|\n",
      "|  1.0|(13461,[1,50,124,...|\n",
      "|  1.0|(13461,[0,1,14,29...|\n",
      "|  0.0|(13461,[5,19,36,4...|\n",
      "|  1.0|(13461,[9,18,40,9...|\n",
      "|  1.0|(13461,[14,32,50,...|\n",
      "|  0.0|(13461,[42,99,101...|\n",
      "|  0.0|(13461,[567,1744,...|\n",
      "|  1.0|(13461,[32,113,11...|\n",
      "|  0.0|(13461,[86,224,37...|\n",
      "|  0.0|(13461,[0,2,52,13...|\n",
      "|  0.0|(13461,[0,77,107,...|\n",
      "|  1.0|(13461,[4,32,35,6...|\n",
      "+-----+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "clean_data.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "(training,testing) = clean_data.randomSplit([0.7,0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "spam_predictor = nb.fit(training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- class: string (nullable = true)\n",
      " |-- text: string (nullable = true)\n",
      " |-- length: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_results = spam_predictor.transform(testing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+--------------------+--------------------+----------+\n",
      "|label|            features|       rawPrediction|         probability|prediction|\n",
      "+-----+--------------------+--------------------+--------------------+----------+\n",
      "|  0.0|(13461,[0,1,2,14,...|[-612.34877984332...|[0.99999999999999...|       0.0|\n",
      "|  0.0|(13461,[0,1,5,15,...|[-742.97388469249...|[1.0,5.5494439698...|       0.0|\n",
      "|  0.0|(13461,[0,1,6,16,...|[-1004.8197043274...|[1.0,5.0315468936...|       0.0|\n",
      "|  0.0|(13461,[0,1,12,34...|[-879.22017540506...|[1.0,1.0023148863...|       0.0|\n",
      "|  0.0|(13461,[0,1,15,33...|[-216.47131414494...|[1.0,1.1962236837...|       0.0|\n",
      "|  0.0|(13461,[0,1,16,21...|[-673.71050817005...|[1.0,1.5549413147...|       0.0|\n",
      "|  0.0|(13461,[0,1,22,26...|[-382.58333036006...|[1.0,1.7564627587...|       0.0|\n",
      "|  0.0|(13461,[0,1,25,66...|[-1361.5572580867...|[1.0,2.1016772175...|       0.0|\n",
      "|  0.0|(13461,[0,1,33,46...|[-378.04433557629...|[1.0,6.5844301586...|       0.0|\n",
      "|  0.0|(13461,[0,1,156,1...|[-251.74061863695...|[0.88109389963478...|       0.0|\n",
      "|  0.0|(13461,[0,1,510,5...|[-325.61601503458...|[0.99999999996808...|       0.0|\n",
      "|  0.0|(13461,[0,1,896,1...|[-96.594570068189...|[0.99999996371517...|       0.0|\n",
      "|  0.0|(13461,[0,2,3,6,7...|[-2547.2759643071...|[1.0,4.6246337876...|       0.0|\n",
      "|  0.0|(13461,[0,2,4,6,8...|[-998.45874047729...|[1.0,1.0141354676...|       0.0|\n",
      "|  0.0|(13461,[0,2,4,9,1...|[-1316.5960403060...|[1.0,4.5097599797...|       0.0|\n",
      "|  0.0|(13461,[0,2,4,28,...|[-430.36443439941...|[1.0,1.3571829152...|       0.0|\n",
      "|  0.0|(13461,[0,2,5,9,7...|[-743.88098750283...|[1.0,1.2520543556...|       0.0|\n",
      "|  0.0|(13461,[0,2,5,15,...|[-1081.6827153914...|[1.0,2.5678822847...|       0.0|\n",
      "|  0.0|(13461,[0,2,5,25,...|[-833.43388596187...|[0.99999999865682...|       0.0|\n",
      "|  0.0|(13461,[0,2,5,25,...|[-492.43437000641...|[1.0,2.6689876137...|       0.0|\n",
      "+-----+--------------------+--------------------+--------------------+----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test_results.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
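    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1 score of the model on the test set was: 0.9248020435242028\n"
     ]
    }
   ],
   "source": [
    "# The evaluator's default metricName is 'f1' (weighted F1), not accuracy;\n",
    "# pass metricName='accuracy' to the constructor to report accuracy instead.\n",
    "f1_eval = MulticlassClassificationEvaluator()\n",
    "f1 = f1_eval.evaluate(test_results)\n",
    "print(\"F1 score of the model on the test set was: {}\".format(f1))"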
   ]
  },
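  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the metric: `MulticlassClassificationEvaluator` reports `f1` unless another `metricName` is passed, and F1 can differ noticeably from accuracy when the classes are imbalanced, as ham and spam are here. The sketch below computes both by hand in plain Python; the `accuracy` and `f1_for_class` helpers and the `pairs` data are hypothetical and only illustrate the definitions.\n",
    "\n",
    "```python\n",
    "def accuracy(pairs):\n",
    "    # Fraction of (label, prediction) pairs that agree.\n",
    "    return sum(1 for y, p in pairs if y == p) / len(pairs)\n",
    "\n",
    "def f1_for_class(pairs, positive):\n",
    "    # F1 for one class, treating it as the positive label.\n",
    "    tp = sum(1 for y, p in pairs if y == positive and p == positive)\n",
    "    fp = sum(1 for y, p in pairs if y != positive and p == positive)\n",
    "    fn = sum(1 for y, p in pairs if y == positive and p != positive)\n",
    "    if tp == 0:\n",
    "        return 0.0\n",
    "    precision, recall = tp / (tp + fp), tp / (tp + fn)\n",
    "    return 2 * precision * recall / (precision + recall)\n",
    "\n",
    "# Hypothetical (label, prediction) pairs: 1.0 = spam, 0.0 = ham.\n",
    "pairs = [(0.0, 0.0)] * 8 + [(1.0, 1.0), (1.0, 0.0)]\n",
    "```\n",
    "\n",
    "Here `accuracy(pairs)` is 0.9, while `f1_for_class(pairs, 1.0)` is only about 0.67, because one of the two spam messages was missed. Spark's `f1` metric averages the per-class scores, weighted by class frequency."
   ]
  },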
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not bad, considering we're just applying straightforward math to text data! Try swapping in other classification models, or come up with other engineered features!\n",
    "\n",
    "## Great Job!"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
